---
layout: post
date: 2023-05-02
title: "SIMD with Zig"
description: "An introduction to using SIMD with Zig"
tags: [zig]
---

<p>To find the index of the first instance of a character within a body of [ASCII] text, you might write something like:</p>

{% highlight zig %}
fn indexOf(haystack: []const u8, needle: u8) ?usize {
  for (haystack, 0..) |c, i| {
    if (c == needle) return i;
  }
  return null;
}
{% endhighlight %}

<p>Or use the <code>std.mem.indexOfScalar</code> function from the standard library, which is essentially the same implementation. This implementation loops through the input and checks each character, one by one, to see if it's equal to our target.</p>

<p>With SIMD, we can leverage CPU instructions to check multiple characters of our input in parallel. Let's do that in Zig.</p>

<p>To keep this simple for now, let's pretend that our input is always exactly 8 characters long (we'll look at dynamic input lengths after, but 8 characters means we can illustrate the full content of our vectors).</p>

<p>As an example, say we have "Hello Jo" and we want the first index of "o" (which is 4). Our first step is to create a vector (think of it as an array) of 8 elements, each containing the value "o":</p>

{% highlight zig %}
const vector_len = 8;
const vector_needles: @Vector(vector_len, u8) = @splat(@as(u8, 'o'));
{% endhighlight %}

<p>Our <code>@as</code> cast is a very Zig-specific thing. <code>'o'</code> on its own is a <code>comptime_int</code>. If we try to use that, we'll get an error message: <em>expected integer, float, bool, or pointer for the vector element type</em>. So we coerce the type to a <code>u8</code>. The more important part of the above code is the <code>@splat</code> builtin, which creates a vector with each element containing the specified value. The length of the created vector is inferred from the type we're assigning to, which above is <code>@Vector(vector_len, u8)</code>. If we were to print <code>vector_needles</code>, we'd see 'o' eight times (the ASCII value of 'o' is 111):</p>

{% highlight text %}
{ 111, 111, 111, 111, 111, 111, 111, 111 }
{% endhighlight %}
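
<p>If you want to try the printing yourself, here is a minimal sketch. It assumes the <code>{any}</code> format specifier of <code>std.debug.print</code>, which is how I'd expect a vector to be printed as a list of its elements:</p>

{% highlight zig %}
const std = @import("std");

const vector_len = 8;
// Eight lanes, each holding the byte value of 'o'.
const vector_needles: @Vector(vector_len, u8) = @splat(@as(u8, 'o'));

pub fn main() void {
  // Assumption: {any} formats a vector element by element.
  std.debug.print("{any}\n", .{vector_needles});
}
{% endhighlight %}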

<p>The next step is to take our input and also convert that into a vector. Because we've said that, for the time being, our input will be limited to 8 characters, this is easy:</p>

{% highlight zig %}
const haystack = "Hello Jo";
const vector_haystack: @Vector(vector_len, u8) = haystack.*;
{% endhighlight %}

<p>The <code>@Vector</code> builtin returns a type, and because our <code>haystack</code> is 8 characters long, the same as our <code>vector_len</code>, we can store our full haystack in this new vector. If we were to print <code>vector_haystack</code>, we'd get:</p>

{% highlight text %}
{ 72, 101, 108, 108, 111, 32, 74, 111 }
{% endhighlight %}

<p>We now have two vectors, each containing eight u8 (byte) values. Our first SIMD operation will be to compare the two using the equality operator (<code>==</code>):</p>

{% highlight zig %}
const matches = vector_haystack == vector_needles;
{% endhighlight %}

<p>This line of code is powerful largely because it's simple. Vectors support the usual arithmetic, bitwise and comparison operators, each applied element by element.</p>

<p>The comparison results in a new vector of type <code>@Vector(vector_len, bool)</code>, and its content will be:</p>

{% highlight text %}
{ false, false, false, false, true, false, false, true }
{% endhighlight %}
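
<p>The same element-by-element behavior applies to other operators. As a small sketch (the values and names here are made up for illustration), adding, masking and comparing two 4-element vectors looks like:</p>

{% highlight zig %}
const std = @import("std");

const a = @Vector(4, u8){ 1, 2, 3, 4 };
const b = @Vector(4, u8){ 10, 2, 30, 1 };

const sum = a + b;    // { 11, 4, 33, 5 }
const masked = a & b; // { 0, 2, 2, 0 }
const less = a < b;   // { true, false, true, false }

pub fn main() void {
  std.debug.print("{any}\n{any}\n{any}\n", .{ sum, masked, less });
}
{% endhighlight %}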

<p>If you look at this closely, you'll note that we're getting close to our answer. The first <code>true</code> occurs at index 4, which is the correct answer. We've compared our haystack with our needle and gotten the result for each index in parallel. But we still need to extract the index from the above. We could loop through <code>matches</code>, but then we'd be back at iterating and comparing one value at a time.</p>

<p>A quick solution to the above problem is to use <code>std.simd.firstTrue</code>, which will give us <code>4</code>:</p>

{% highlight zig %}
const vector_len = 8;
const vector_needles: @Vector(vector_len, u8) = @splat(@as(u8, 'o'));

const haystack = "Hello Jo";
const vector_haystack: @Vector(vector_len, u8) = haystack.*;

const matches = vector_haystack == vector_needles;

const index = std.simd.firstTrue(matches);
{% endhighlight %}

<p>But how does <code>firstTrue</code> work? Let's build our own simple implementation:</p>

<p>The first thing that we'll do is check if we have <strong>any</strong> matches. This is possibly a step that we can avoid, but it'll help us get the ball rolling:</p>

{% highlight zig %}
if (!@reduce(.Or, matches)) {
  return null;
}
// TODO
{% endhighlight %}

<p>The <code>@reduce</code> builtin takes a vector, applies the operation, and returns a scalar (a single value). Here we're applying the <code>std.builtin.ReduceOp.Or</code> operation on our matches. If any of the values are <code>true</code>, this will return <code>true</code>. If all values are <code>false</code>, this will return <code>false</code>. The available operations are <code>.And</code>, <code>.Or</code>, <code>.Xor</code>, <code>.Min</code>, <code>.Max</code>, <code>.Add</code> and <code>.Mul</code> (some operations are only available for some types of vectors, for example <code>.Add</code> doesn't make sense for a vector of booleans).</p>
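
<p>To get a feel for the other operations, here is a short sketch that reduces a small integer vector and a small boolean vector. The vectors and names are invented for this example:</p>

{% highlight zig %}
const std = @import("std");

const values = @Vector(4, u8){ 3, 9, 4, 1 };
const flags = @Vector(4, bool){ true, false, true, true };

const total = @reduce(.Add, values);    // 3 + 9 + 4 + 1 = 17
const largest = @reduce(.Max, values);  // 9
const smallest = @reduce(.Min, values); // 1
const all_set = @reduce(.And, flags);   // false, one lane is false
const any_set = @reduce(.Or, flags);    // true, at least one lane is true

pub fn main() void {
  std.debug.print("{} {} {} {} {}\n", .{ total, largest, smallest, all_set, any_set });
}
{% endhighlight %}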

<p>If we get past this check, we know that we have at least 1 match (i.e. a <code>true</code>) in our vector. How do we get its index? Admittedly, it's a bit more complicated than you might think. We'll need to use the <code>@select</code> builtin. <code>@select</code> is a bit like a parallel if statement, but if we were to write it using normal code, it might look like:</p>

{% highlight zig %}
fn select(comptime T: type, pred: [8]bool, a: [8]T, b: [8]T) [8]T {
  var out: [8]T = undefined;
  for (pred, 0..) |p, i| {
    out[i] = if (p) a[i] else b[i];
  }
  return out;
}
{% endhighlight %}

<p>It always takes a vector of booleans, and based on the true/false values within this vector, it will select values from either input source <code>a</code> (when <code>true</code>) or input source <code>b</code> (when <code>false</code>).</p>

<p>How does <code>@select</code> help us? First we'll create two new vectors. These will act as our input source <code>a</code> and input source <code>b</code>. The first is a vector that contains the index of each index. Huh? Let's let the code explain:</p>

{% highlight zig %}
const indexes = std.simd.iota(u8, vector_len);
{% endhighlight %}

<p>If we print it out, we get:</p>

{% highlight text %}
{ 0, 1, 2, 3, 4, 5, 6, 7 }
{% endhighlight %}

<p>It's ok if it isn't immediately obvious how that helps. Our next vector will be even weirder:</p>

{% highlight zig %}
const nulls: @Vector(vector_len, u8) = @splat(@as(u8, 255));
{% endhighlight %}

<p>We've seen <code>@splat</code> already, so we know the output to this is:</p>

{% highlight text %}
{ 255, 255, 255, 255, 255, 255, 255, 255 }
{% endhighlight %}

<p>We now have 3 vectors, and we know a little bit about <code>@select</code>. Let's look at those 3 vectors side by side, and see if any pattern emerges:</p>

{% highlight text %}
{ false, false, false, false, true, false, false, true }
{ 0,     1,     2,     3,     4,    5,     6,     7 }
{ 255,   255,   255,   255,   255,  255,   255,   255 }
{% endhighlight %}

<p>I still don't really see any pattern here. But what if we pass these into <code>@select</code>:</p>

{% highlight zig %}
const result = @select(u8, matches, indexes, nulls);
{% endhighlight %}

<p>This results in:</p>

{% highlight text %}
{ 255, 255, 255, 255, 4, 255, 255, 7 }
{% endhighlight %}

<p>We'll revisit <code>@select</code> in a second, but I see something I want to keep exploring: the smallest value in this vector is what we're after. We want to extract the <code>4</code> and we already know that <code>@reduce</code> is our tool for turning a vector into a scalar. We also briefly mentioned that one of the operations is <code>.Min</code>. Putting that together, we end up with:</p>

{% highlight zig %}
const index = @reduce(.Min, result);
{% endhighlight %}

<p>Which returns <code>4</code>, our final and correct answer.</p>

<p>There are a few things we still want to explore, but first let's go back and look at <code>@select</code>. We had a vector of booleans and we wanted to get the first index which was <code>true</code>. The first thing we did was create another vector that just contains the indexes. On its own, that didn't help. What we needed was a vector with indexes when our match was true, and some invalid value when our match was false. Only this way could we use <code>@reduce</code> (as it would "ignore" invalid values from non-matches). For our invalid value, we used <code>255</code>. We couldn't use <code>null</code> (vector elements must be integers, floats, booleans or pointers) and we couldn't use a negative number since we relied on the <code>.Min</code> operator to extract our first (valid) index. Since we're limiting ourselves to a <code>vector_len</code> of 8, our valid indexes are 0 through 7, so we could have used any value greater than 7. If our vector was larger, say 512 elements, we could simply use 512 or any larger value (we'd have to change the type of our indexes and nulls vectors from u8 to u16 or something, but it would all still work fine).</p>
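
<p>As a rough sketch of that last point, here is what the same approach could look like with <code>u16</code> indexes. The <code>firstTrueIndex</code> helper is a name I made up for this example and isn't part of the standard library. It also drops the separate <code>.Or</code> check: if the minimum is still the invalid value, nothing matched.</p>

{% highlight zig %}
const std = @import("std");

// Hypothetical helper (not part of std): index of the first `true`
// in `matches`, or null when no lane is true.
fn firstTrueIndex(comptime len: usize, matches: @Vector(len, bool)) ?u16 {
  // u16 indexes leave room for vectors longer than 255 elements.
  const invalid = std.math.maxInt(u16);
  const indexes = std.simd.iota(u16, len);
  const nulls: @Vector(len, u16) = @splat(@as(u16, invalid));

  // Keep the index where we matched, the invalid value elsewhere.
  const result = @select(u16, matches, indexes, nulls);
  const min = @reduce(.Min, result);

  // If every lane holds the invalid value, there was no match.
  return if (min == invalid) null else min;
}

pub fn main() void {
  const haystack: @Vector(8, u8) = "Hello Jo".*;
  const needles: @Vector(8, u8) = @splat(@as(u8, 'o'));
  std.debug.print("{any}\n", .{firstTrueIndex(8, haystack == needles)});
}
{% endhighlight %}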

<p><code>@select</code> gave us a vector that we could <code>@reduce</code> to get our desired result. But it wasn't immediately obvious, to me at least, how to get there. When I think about "finding the index of the first true", I think about iterating one value at a time and discarding any non-matches. This isn't a suitable mindset when working with vectors. We couldn't discard individual elements - we had to transform them in a way that would let us use a reduce operation.</p>

<h3>Vector Length</h3>
<p>There are two details we need to figure out. The first is: what happens when our input isn't the same size as our vector length? This is easy to solve: we process the input one <code>vector_len</code> at a time.</p>

{% highlight zig %}
fn firstIndexOf(haystack: []const u8, needle: u8) ?usize {
  const vector_len = 8;


  // {111, 111, 111, 111, 111, 111, 111, 111}
  const vector_needles: @Vector(vector_len, u8) = @splat(@as(u8, needle));

  // Because we're implementing our own std.simd.firstTrue
  // we can move the following two vectors, indexes and nulls,
  // outside the loop and re-use them.

  // {0, 1, 2, 3, 4, 5, 6, 7}
  const indexes = std.simd.iota(u8, vector_len);

  // {255, 255, 255, 255, 255, 255, 255, 255}
  const nulls: @Vector(vector_len, u8) = @splat(@as(u8, 255));

  var pos: usize = 0;
  var left = haystack.len;
  while (left > 0) {
    if (left < vector_len) {
      // fall back to a normal scan when our input (or what's left
      // of it) is smaller than our vector_len
      return std.mem.indexOfScalarPos(u8, haystack, pos, needle);
    }

    const h: @Vector(vector_len, u8) = haystack[pos..][0..vector_len].*;
    const matches = h == vector_needles;

    if (@reduce(.Or, matches)) {
      // we found a match, we just need to find its index
      const result = @select(u8, matches, indexes, nulls);

      // we have to add pos to this value, since this is merely
      // the index within this vector_len chunk (e.g. 0-7).
      return @reduce(.Min, result) + pos;
    }

    pos += vector_len;
    left -= vector_len;
  }
  return null;
}
{% endhighlight %}

<p>Every iteration of the loop examines <code>vector_len</code> characters at a time. If there's no match within one iteration, we move to the next <code>vector_len</code>. It's possible that our input is less than <code>vector_len</code>, either originally, or eventually as we iterate through it. When this happens, we fall back to the standard linear search.</p>

<p>The second question we must answer is: what vector length should we use? So far, we've been using <code>8</code>, but would larger vectors run faster? One option is to use <code>std.simd.suggestVectorSize</code>, which takes the element type (<code>u8</code> in our case) and returns a <code>?usize</code>. This function attempts to return the maximum supported vector length of the system, or null if SIMD operations aren't supported. But using the largest possible vector length won't <em>always</em> be ideal. You need to consider your input length. Using a vector length of 512 with inputs that are small (say 128-2048) will result in the fallback to <code>std.mem.indexOfScalarPos</code> being executed often. One solution to this problem is to use different vector lengths based on the input length. On the flip side, if you're dealing with large inputs (megabytes and beyond), a large vector length would be ideal as you'll spend comparatively little time dealing with small parts of the input.</p>
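
<p>Here is a sketch of how the suggested length could be wired in. It assumes <code>std.simd.suggestVectorSize</code> is available under that name (as far as I know, later Zig releases rename it to <code>suggestVectorLength</code>), and it leans on <code>std.simd.firstTrue</code> rather than our own version so we don't have to worry about the index type. <code>firstIndexOfSuggested</code> is just a name for this example:</p>

{% highlight zig %}
const std = @import("std");

// Assumption: suggestVectorSize exists under this name in your Zig
// version. We fall back to 8 lanes when it returns null.
const vector_len = std.simd.suggestVectorSize(u8) orelse 8;

// Hypothetical variant of firstIndexOf using the suggested length.
fn firstIndexOfSuggested(haystack: []const u8, needle: u8) ?usize {
  const vector_needles: @Vector(vector_len, u8) = @splat(needle);

  var pos: usize = 0;
  // Process full vector_len chunks while enough input remains.
  while (haystack.len - pos >= vector_len) : (pos += vector_len) {
    const h: @Vector(vector_len, u8) = haystack[pos..][0..vector_len].*;
    if (std.simd.firstTrue(h == vector_needles)) |i| {
      // i is the index within this chunk, so offset it by pos.
      return pos + i;
    }
  }

  // Fewer than vector_len bytes are left: plain linear scan.
  return std.mem.indexOfScalarPos(u8, haystack, pos, needle);
}
{% endhighlight %}

<p>Whether this beats a fixed length of 8 depends on your inputs, which is exactly why it's worth benchmarking both.</p>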

<h3>Conclusion</h3>
<p>SIMD requires a different way to think about variables and how they are processed. It's also an optimization technique that's up to the developer; there's no compiler or runtime help. You need to benchmark, test and tweak in order to figure out what works best. Benchmarking is particularly important because, unless you're dealing with large data or very hot code, there's a good chance that the effort won't yield measurable benefits.</p>

<p>Having said that, Zig exposes a pretty concise API: a few builtins (<code>@splat</code>, <code>@select</code>, <code>@Vector</code> and <code>@reduce</code>) along with functions in <code>std.simd</code>. Once you understand these functions and how they can work together, experimentation is straightforward.</p>
